Signed-off-by: qiyuw <qiyuw@nvidia.com>
Fix partial amax kernel and add clear transpose cache. See merge request qiyuw/transformerengine!1
Fix partial cast numerical bug. See merge request qiyuw/transformerengine!2
added 16 commits (January 5, 2026 19:05)
This reverts commit 3c4adf4.
for more information, see https://pre-commit.ci
Contributor
Greptile Summary

This PR implements NVFP4 partial cast infrastructure for distributed training with ZeRO/FSDP optimizers, enabling efficient FP32→NVFP4 conversion of weight shards with coordinated scaling across data parallel ranks.

Key additions:
Critical issue:
Testing:
Performance impact:
Confidence Score: 4/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Master Weights FP32<br/>Sharded across DP ranks] --> B[Batched dtype conversion<br/>torch.cat + to + split]
    B --> C[Multi-tensor partial amax<br/>nvfp4_multi_tensor_compute_partial_amax]
    C --> D[Per-block amax<br/>16x16 tiles]
    C --> E[Global amax<br/>per tensor]
    D --> F[AllReduce MAX<br/>block amax across DP]
    E --> G[AllReduce MAX<br/>global amax across DP]
    F --> H[Fused scale kernel<br/>nvfp4_fused_scale]
    G --> H
    H --> I[Compute per-block decode scale<br/>block_amax * 448 / global_amax]
    H --> J[Expand to row-level<br/>Convert to FP8 E4M3]
    H --> K[Copy global amax to target]
    I --> L[Multi-tensor partial cast<br/>nvfp4_multi_tensor_2d_partial_cast]
    J --> L
    L --> M[NVFP4 packed data<br/>2 nibbles per byte<br/>nibble-accurate updates]
    M --> N[AllGather<br/>Gather full model weights]
    N --> O[Multi-tensor columnwise creation<br/>nvfp4_multi_tensor_create_columnwise]
    O --> P[NVFP4 transpose<br/>Nibble repacking]
    O --> Q[Scale transpose<br/>Rowwise to columnwise]
    P --> R[Ready for GEMM<br/>Columnwise data + scales]
    Q --> R
```
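For concreteness, here is a minimal PyTorch sketch of two details from the flowchart: the per-block decode-scale arithmetic and the two-nibbles-per-byte packing. The function names and the low/high-nibble convention are illustrative assumptions, not the PR's actual kernels.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3fn

def block_decode_scales(block_amax: torch.Tensor,
                        global_amax: torch.Tensor) -> torch.Tensor:
    # Per-block decode scale from the flowchart: block_amax * 448 / global_amax,
    # stored as FP8 E4M3 (needs a PyTorch build with torch.float8_e4m3fn).
    return (block_amax * FP8_E4M3_MAX / global_amax).to(torch.float8_e4m3fn)

def pack_nibbles(codes: torch.Tensor) -> torch.Tensor:
    # Two 4-bit FP4 codes per byte; even elements in the low nibble, odd
    # elements in the high nibble (an assumed convention for illustration).
    pairs = codes.to(torch.uint8).reshape(-1, 2)  # length must be even
    return pairs[:, 0] | (pairs[:, 1] << 4)

def unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
    return torch.stack([packed & 0xF, packed >> 4], dim=1).reshape(-1)

codes = torch.randint(0, 16, (8,))
assert torch.equal(unpack_nibbles(pack_nibbles(codes)), codes.to(torch.uint8))
```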
Last reviewed commit: 687c8b6
```python
elif isinstance(tensor, NVFP4Tensor):
    old_rowwise = tensor._rowwise_data
    assert old_rowwise.dtype == new_raw_data.dtype, "The data types of raw data don't match"
    new_rowwise_data.detach().copy_(old_rowwise)
```
Contributor
`new_rowwise_data` is undefined.

Suggested change:

```diff
-    new_rowwise_data.detach().copy_(old_rowwise)
+    new_raw_data.detach().copy_(old_rowwise)
```
Author
It is hard to add sign-off to previous commits, so I reopened this as a new PR: #2691
Description
This PR adds NVFP4 partial cast support for distributed training with ZeRO/FSDP optimizers. It enables efficient casting of FP32 master weight shards to NVFP4 model weights with coordinated scaling across data parallel ranks, while minimizing CPU overhead in large-scale training.
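As a rough illustration of the coordinated scaling described above (the helper name and reduction pattern here are assumptions, not the PR's code): each rank holds only a shard of a weight, so its local amax is partial, and a MAX all-reduce over the data-parallel group is what lets every rank quantize against the same full-tensor amax.

```python
import torch
import torch.distributed as dist

def coordinated_global_amax(local_shard: torch.Tensor,
                            dp_group: dist.ProcessGroup) -> torch.Tensor:
    # The shard sees only part of the weight, so this amax is partial...
    amax = local_shard.detach().abs().amax().float()
    # ...and the MAX all-reduce recovers the full-tensor amax, giving every
    # DP rank identical NVFP4 scales for its part of the cast.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group)
    return amax
```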
Type of change
Changes
This PR introduces NVFP4 partial cast infrastructure and optimizations for distributed training:
- **NVFP4 Partial Cast Kernel** (`nvfp4_2d_partial_cast`)
- **NVFP4 Transpose Kernel** (`nvfp4_transpose`): `uint2` loads/stores with 64×64 tiles for efficient memory access
- **Fused Scale Kernel** (`nvfp4_fused_scale`)
- **Multi-Tensor Dispatch Pattern**
- **CPU Overhead Optimizations**
  - Batched dtype conversion via `torch.cat`/`torch.split` (see the sketch after this list)
  - Replaced `torch.zeros()` with `torch.empty()` for immediately written buffers
- **Scale Computation Improvements**
- **New Public API**: `cast_master_weights_to_nvfp4()`
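The batched dtype conversion under "CPU Overhead Optimizations" can be pictured roughly like this (helper name and target dtype are illustrative): one `cat`, one cast, and one `split` replace a Python loop of per-tensor `.to()` calls, cutting CPU dispatch and kernel-launch overhead when there are many shards.

```python
import torch

def batched_cast(shards: list[torch.Tensor],
                 dtype: torch.dtype = torch.bfloat16) -> list[torch.Tensor]:
    # One cat + one conversion + one split instead of len(shards) `.to()` calls.
    sizes = [s.numel() for s in shards]
    flat = torch.cat([s.reshape(-1) for s in shards]).to(dtype)
    return [t.reshape(s.shape) for t, s in zip(flat.split(sizes), shards)]
```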
Testing

- `test_nvfp4_transpose_kernel`
- `test_nvfp4_partial_cast_matches_full`
- `test_single_gpu_partial_cast_vs_full`
- `test_cast_master_weights_to_nvfp4`

This feature also passed numeric validation in GPT-3 training on the corresponding Megatron-Core branch:
https://gitlab-master.nvidia.com/qiyuw/megatron-lm-all/-/tree/fp4_primary_opt?ref_type=heads
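For illustration, here is a self-contained sketch of the invariant the partial-cast tests above exercise (block size, shard boundaries, and helper names are all assumed): per-block amaxes computed from shard-local data and then max-reduced must equal the per-block amaxes of the full tensor, so a partial cast ends up with the same scales as a full cast.

```python
import torch

def partial_block_amax(flat, start, end, n_blocks, block=16):
    # Per-block amax contributed by the shard flat[start:end]; blocks the
    # shard does not touch stay 0, the identity for a max-reduction.
    out = torch.zeros(n_blocks)
    blk = torch.arange(start, end) // block
    out.scatter_reduce_(0, blk, flat[start:end].abs(), reduce="amax")
    return out

torch.manual_seed(0)
full = torch.randn(1024)
n_blocks = full.numel() // 16
bounds = [0, 300, 512, 777, 1024]  # uneven ZeRO-style shard boundaries
partials = [partial_block_amax(full, a, b, n_blocks)
            for a, b in zip(bounds, bounds[1:])]
reduced = torch.stack(partials).amax(dim=0)  # stands in for AllReduce MAX
assert torch.equal(reduced, full.abs().reshape(-1, 16).amax(dim=1))
```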
Checklist: